# Adapted from
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    LinearMethodBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)


class MixtralMLP(nn.Module):
    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

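        # Each expert lives entirely on one tensor-parallel rank (see MixtralMoE),
        # so its projections are plain ReplicatedLinear layers rather than
        # column/row-parallel ones.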
        self.w1 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
        )

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
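        # SwiGLU-style MLP: w1 is the gate projection (followed by SiLU),
        # w3 is the up projection, and w2 projects back down to hidden_size.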
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indices = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indices:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList(
            [
                (
                    MixtralMLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method,
                    )
                    if idx in self.expert_indices
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)

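        # Per-token routing: softmax over all experts, keep the top-k weights,
        # then renormalize them so they sum to one for each token.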
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

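        # Each rank only evaluates its local experts; tokens routed to experts
        # on other ranks contribute zero here and are filled in by the
        # tensor-parallel all-reduce at the end.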
        final_hidden_states = None
        for expert_idx in self.expert_indices:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states)


class MixtralAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
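        # RadixAttention is sglang's attention op; layer_id selects this layer's
        # slice of the KV cache pool managed by the runtime.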
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
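        # Fused QKV projection, split into Q/K/V, apply rotary position
        # embeddings, run attention, then project back to hidden_size.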
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method,
        )
        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
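        # vLLM's RMSNorm fuses the residual add: when a residual is passed in,
        # it returns (normalized hidden states, updated residual).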
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(config, i, linear_method=linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
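        # sglang's LogitsProcessor projects the hidden states with the LM head
        # weight and returns the logits used for sampling.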
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
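        # HF checkpoints store q_proj/k_proj/v_proj separately; map each shard
        # into the fused qkv_proj parameter of QKVParallelLinear.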
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path,
            cache_dir,
            load_format,
            revision,
            fall_back_to_pt=False,
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if "block_sparse_moe.experts." in name and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = MixtralForCausalLM
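# EntryClass is how sglang's model loader discovers which class implements this
# architecture. A minimal launch sketch (assuming the usual launch_server flags,
# which may differ across sglang versions):
#   python -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --tp 2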
